Your Name: Roshan Sundar
For this assignment, you will do an ablation study on the DCGAN model discussed in class and implement WGAN with weight clipping and (optionally) WGAN with gradient penalty.
An ablation study measures performance changes after changing certain components in the AI system. The goal is to understand the contribution of each component to the overall system.
Here is a copy of the code implementation from the course website. Please run the code to obtain the result and use it as a baseline against which to compare the results of the following ablation tasks.
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
# Set random seed for reproducibility across python, torch CPU, and torch CUDA RNGs.
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.cuda.manual_seed(manualSeed)
torch.backends.cudnn.deterministic = True
# BUG FIX: the cudnn flag is `benchmark` (singular). Assigning to the
# misspelled `benchmarks` silently creates an unused attribute and leaves
# cudnn autotuning enabled, which breaks run-to-run reproducibility.
torch.backends.cudnn.benchmark = False
os.environ['PYTHONHASHSEED'] = str(manualSeed)
# Root directory for dataset (unused here: MNIST is downloaded to 'data' below)
# dataroot = "data/celeba"
# Number of workers for dataloader
workers = 1
# Batch size during training
batch_size = 128
# Spatial size of training images. All images will be resized to this
# size using a transformer.
#image_size = 64
image_size = 32
# Number of channels in the training images. For color images this is 3
# (MNIST is grayscale, hence 1)
#nc = 3
nc = 1
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator (reduced from the tutorial's 64 to speed up training)
#ngf = 64
ngf = 8
# Size of feature maps in discriminator (reduced from 64 to match the generator)
#ndf = 64
ndf = 8
# Number of training epochs
num_epochs = 5
# Epoch/iteration budgets, presumably for the WGAN experiments later in
# the assignment — TODO confirm where these are consumed
num_epochs_wgan = 15
num_iters = 250
# Learning rate for optimizers
lr = 0.0002
# RMSprop learning rate, presumably for the WGAN training — TODO confirm
lr_rms = 5e-4
# Beta1 hyperparam for Adam optimizers (0.5 per the DCGAN paper)
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
# Initialize BCELoss function (standard GAN loss on D's sigmoid output)
criterion = nn.BCELoss()
# Create batch of latent vectors that we will use to visualize
# the progression of the generator (fixed so grids are comparable over time)
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# Establish convention for real and fake labels during training
real_label = 1.0
fake_label = 0.0
# Several useful functions
def initialize_net(net_class, init_method, device, ngpu):
# Create the generator
net_inst = net_class(ngpu).to(device)
# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
net_inst = nn.DataParallel(net_inst, list(range(ngpu)))
# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.2.
if init_method is not None:
net_inst.apply(init_method)
# Print the model
print(net_inst)
return net_inst
def plot_GAN_loss(losses, labels):
    """Plot one training-loss curve per (loss, label) pair on a shared axis."""
    plt.figure(figsize=(10, 5))
    plt.title("Losses During Training")
    for curve, name in zip(losses, labels):
        plt.plot(curve, label=f"{name}")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()
def plot_real_fake_images(real_batch, fake_batch):
    """Show up to 64 real images beside the last saved grid of fakes.

    `real_batch` is a dataloader batch (images at index 0); `fake_batch`
    is the list of image grids collected during training — only its final
    entry is shown.
    """
    plt.figure(figsize=(15, 15))
    # Left panel: real samples from the dataset.
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Real Images")
    real_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True)
    plt.imshow(np.transpose(real_grid.cpu(), (1, 2, 0)))
    # Right panel: generator output recorded at the last checkpoint.
    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(fake_batch[-1], (1, 2, 0)))
    plt.show()
# custom weights initialization called on netG and netD via .apply()
def weights_init(m):
    """DCGAN initialization: Conv weights ~ N(0, 0.02); BatchNorm weights
    ~ N(1, 0.02) with biases zeroed. Other layer types are left untouched."""
    layer_name = m.__class__.__name__
    if 'Conv' in layer_name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in layer_name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
# Download MNIST and rescale every image to image_size x image_size in [-1, 1].
dataset = dset.MNIST(
    'data', train=True, download=True,
    transform=transforms.Compose([
        # 28 x 28 -> 32 x 32 so each spatial dimension is a power of two.
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # Map pixels from [0, 1] to [-1, 1] to match the generator's Tanh output.
        transforms.Normalize((0.5,), (0.5,)),
    ]))

# Shuffled batch iterator over the training set.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

# Sanity check: display a grid of 64 training images.
real_batch = next(iter(dataloader))
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
preview = vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True)
plt.imshow(np.transpose(preview.cpu(), (1, 2, 0)))
Random Seed: 999 Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 83598843.50it/s]
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 48261232.60it/s]
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 55005459.25it/s]
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 20159289.70it/s]
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw
<matplotlib.image.AxesImage at 0x7fb18472ffa0>
# Generator Code
class Generator(nn.Module):
    """DCGAN generator: upsamples a latent vector (nz x 1 x 1) to an
    nc x 32 x 32 image through strided transposed convolutions."""

    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        stages = [
            # latent nz x 1 x 1 -> (ngf*4) x 4 x 4
            nn.ConvTranspose2d(nz, ngf * 4, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),  # inplace ReLU
            # -> (ngf*2) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> ngf x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> nc x 32 x 32; Tanh squashes to [-1, 1] to match the
            # normalization applied to the real images.
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        return self.main(input)
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an nc x 32 x 32 image to a scalar
    real/fake probability via strided convolutions and a final Sigmoid."""

    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        stages = [
            # nc x 32 x 32 -> ndf x 16 x 16 (no BatchNorm on the input layer)
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*2) x 8 x 8
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> 1 x 1 x 1; Sigmoid produces the probability
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        return self.main(input)
# Initialize networks: baseline DCGAN generator and discriminator with
# the custom N(0, 0.02) weight initialization.
netG = initialize_net(Generator, weights_init, device, ngpu)
netD = initialize_net(Discriminator, weights_init, device, ngpu)
# Setup Adam optimizers for both G and D (lr and beta1 follow the DCGAN paper)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
# Training Loop for the baseline DCGAN.
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
# For each batch in the dataloader
for i, data in enumerate(dataloader, 0):
############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
## Train with all-real batch
netD.zero_grad()
# Format batch
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size,), real_label, device=device)
# Forward pass real batch through D
output = netD(real_cpu).view(-1)
# Calculate loss on all-real batch
errD_real = criterion(output, label)
# Calculate gradients for D in backward pass
errD_real.backward()
D_x = output.mean().item()
## Train with all-fake batch
# Generate batch of latent vectors
noise = torch.randn(b_size, nz, 1, 1, device=device)
# Generate fake image batch with G
fake = netG(noise)
label.fill_(fake_label)
# Classify all fake batch with D
# detach() stops this D update from backpropagating into G
output = netD(fake.detach()).view(-1)
# Calculate D's loss on the all-fake batch
errD_fake = criterion(output, label)
# Calculate the gradients for this batch
# (they accumulate on top of the gradients from the real batch)
errD_fake.backward()
D_G_z1 = output.mean().item()
# Add the gradients from the all-real and all-fake batches
errD = errD_real + errD_fake
# Update D
optimizerD.step()
############################
# (2) Update G network: maximize log(D(G(z)))
###########################
netG.zero_grad()
label.fill_(real_label) # fake labels are real for generator cost
# Since we just updated D, perform another forward pass of all-fake batch through D
output = netD(fake).view(-1)
# Calculate G's loss based on this output
errG = criterion(output, label)
# Calculate gradients for G
errG.backward()
D_G_z2 = output.mean().item()
# Update G
optimizerG.step()
# Output training stats
if i % 50 == 0:
print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
% (epoch, num_epochs, i, len(dataloader),
errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
# Save Losses for plotting later
G_losses.append(errG.item())
D_losses.append(errD.item())
# Check how the generator is doing by saving G's output on fixed_noise
# (the fixed latent batch keeps the saved grids comparable over training)
if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
with torch.no_grad():
fake = netG(fixed_noise).detach().cpu()
img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
iters += 1
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): Tanh()
)
)
Discriminator(
(main): Sequential(
(0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(9): Sigmoid()
)
)
Starting Training Loop...
[0/5][0/469] Loss_D: 1.4365 Loss_G: 0.7517 D(x): 0.4824 D(G(z)): 0.5008 / 0.4750
[0/5][50/469] Loss_D: 0.5903 Loss_G: 1.3937 D(x): 0.8008 D(G(z)): 0.3003 / 0.2546
[0/5][100/469] Loss_D: 0.3073 Loss_G: 2.1208 D(x): 0.8771 D(G(z)): 0.1559 / 0.1266
[0/5][150/469] Loss_D: 0.1544 Loss_G: 2.6181 D(x): 0.9450 D(G(z)): 0.0913 / 0.0820
[0/5][200/469] Loss_D: 0.0581 Loss_G: 3.6729 D(x): 0.9750 D(G(z)): 0.0318 / 0.0298
[0/5][250/469] Loss_D: 0.0407 Loss_G: 4.0361 D(x): 0.9811 D(G(z)): 0.0211 / 0.0197
[0/5][300/469] Loss_D: 0.0252 Loss_G: 4.5581 D(x): 0.9889 D(G(z)): 0.0138 / 0.0111
[0/5][350/469] Loss_D: 0.0203 Loss_G: 4.7004 D(x): 0.9906 D(G(z)): 0.0107 / 0.0107
[0/5][400/469] Loss_D: 0.0218 Loss_G: 5.2893 D(x): 0.9917 D(G(z)): 0.0132 / 0.0060
[0/5][450/469] Loss_D: 0.0674 Loss_G: 3.8648 D(x): 0.9650 D(G(z)): 0.0306 / 0.0262
[1/5][0/469] Loss_D: 0.0479 Loss_G: 4.0789 D(x): 0.9730 D(G(z)): 0.0200 / 0.0197
[1/5][50/469] Loss_D: 0.2738 Loss_G: 8.7382 D(x): 0.9933 D(G(z)): 0.2230 / 0.0002
[1/5][100/469] Loss_D: 0.0595 Loss_G: 4.0230 D(x): 0.9713 D(G(z)): 0.0290 / 0.0235
[1/5][150/469] Loss_D: 0.0559 Loss_G: 3.8343 D(x): 0.9662 D(G(z)): 0.0207 / 0.0261
[1/5][200/469] Loss_D: 0.0792 Loss_G: 3.6427 D(x): 0.9641 D(G(z)): 0.0409 / 0.0317
[1/5][250/469] Loss_D: 0.0777 Loss_G: 3.8496 D(x): 0.9671 D(G(z)): 0.0424 / 0.0268
[1/5][300/469] Loss_D: 0.0907 Loss_G: 3.7697 D(x): 0.9533 D(G(z)): 0.0408 / 0.0288
[1/5][350/469] Loss_D: 0.8933 Loss_G: 1.4637 D(x): 0.5490 D(G(z)): 0.1691 / 0.2841
[1/5][400/469] Loss_D: 0.1584 Loss_G: 3.8905 D(x): 0.9555 D(G(z)): 0.1030 / 0.0254
[1/5][450/469] Loss_D: 0.2203 Loss_G: 3.0408 D(x): 0.9425 D(G(z)): 0.1404 / 0.0584
[2/5][0/469] Loss_D: 1.1430 Loss_G: 0.1740 D(x): 0.3794 D(G(z)): 0.0053 / 0.8461
[2/5][50/469] Loss_D: 0.3015 Loss_G: 1.7689 D(x): 0.7916 D(G(z)): 0.0515 / 0.1978
[2/5][100/469] Loss_D: 0.2152 Loss_G: 2.7790 D(x): 0.9259 D(G(z)): 0.1240 / 0.0731
[2/5][150/469] Loss_D: 0.2885 Loss_G: 2.7082 D(x): 0.9502 D(G(z)): 0.2032 / 0.0785
[2/5][200/469] Loss_D: 0.2895 Loss_G: 2.0574 D(x): 0.8475 D(G(z)): 0.1050 / 0.1472
[2/5][250/469] Loss_D: 0.3635 Loss_G: 3.0346 D(x): 0.9434 D(G(z)): 0.2548 / 0.0553
[2/5][300/469] Loss_D: 0.2832 Loss_G: 2.0253 D(x): 0.8614 D(G(z)): 0.1184 / 0.1476
[2/5][350/469] Loss_D: 0.3049 Loss_G: 2.5284 D(x): 0.8566 D(G(z)): 0.1308 / 0.0929
[2/5][400/469] Loss_D: 1.0507 Loss_G: 0.3886 D(x): 0.4049 D(G(z)): 0.0149 / 0.6924
[2/5][450/469] Loss_D: 0.3168 Loss_G: 2.1532 D(x): 0.8296 D(G(z)): 0.1110 / 0.1365
[3/5][0/469] Loss_D: 0.5010 Loss_G: 0.9333 D(x): 0.6536 D(G(z)): 0.0417 / 0.4269
[3/5][50/469] Loss_D: 0.5283 Loss_G: 3.0217 D(x): 0.9592 D(G(z)): 0.3636 / 0.0608
[3/5][100/469] Loss_D: 0.4765 Loss_G: 2.8091 D(x): 0.9311 D(G(z)): 0.3184 / 0.0700
[3/5][150/469] Loss_D: 0.4758 Loss_G: 2.9791 D(x): 0.9065 D(G(z)): 0.2969 / 0.0627
[3/5][200/469] Loss_D: 0.3692 Loss_G: 2.1616 D(x): 0.8159 D(G(z)): 0.1423 / 0.1310
[3/5][250/469] Loss_D: 0.5128 Loss_G: 3.4594 D(x): 0.9268 D(G(z)): 0.3334 / 0.0400
[3/5][300/469] Loss_D: 0.5232 Loss_G: 2.4309 D(x): 0.9159 D(G(z)): 0.3338 / 0.1045
[3/5][350/469] Loss_D: 0.4127 Loss_G: 1.7225 D(x): 0.8076 D(G(z)): 0.1656 / 0.2019
[3/5][400/469] Loss_D: 0.4305 Loss_G: 2.1135 D(x): 0.8328 D(G(z)): 0.2017 / 0.1468
[3/5][450/469] Loss_D: 0.7505 Loss_G: 2.5221 D(x): 0.9507 D(G(z)): 0.4778 / 0.0965
[4/5][0/469] Loss_D: 0.4592 Loss_G: 2.6408 D(x): 0.8447 D(G(z)): 0.2386 / 0.0837
[4/5][50/469] Loss_D: 0.8086 Loss_G: 3.1988 D(x): 0.9124 D(G(z)): 0.4908 / 0.0480
[4/5][100/469] Loss_D: 0.5029 Loss_G: 1.3491 D(x): 0.7026 D(G(z)): 0.1133 / 0.2842
[4/5][150/469] Loss_D: 0.4115 Loss_G: 2.0351 D(x): 0.8405 D(G(z)): 0.1955 / 0.1500
[4/5][200/469] Loss_D: 0.4738 Loss_G: 2.3371 D(x): 0.8667 D(G(z)): 0.2622 / 0.1149
[4/5][250/469] Loss_D: 0.6940 Loss_G: 2.4102 D(x): 0.9381 D(G(z)): 0.4424 / 0.1082
[4/5][300/469] Loss_D: 0.5978 Loss_G: 1.9517 D(x): 0.8353 D(G(z)): 0.3174 / 0.1689
[4/5][350/469] Loss_D: 0.4256 Loss_G: 1.9415 D(x): 0.8672 D(G(z)): 0.2332 / 0.1613
[4/5][400/469] Loss_D: 0.8483 Loss_G: 2.2737 D(x): 0.9300 D(G(z)): 0.5095 / 0.1225
[4/5][450/469] Loss_D: 0.5946 Loss_G: 1.7913 D(x): 0.8973 D(G(z)): 0.3596 / 0.1942
# plot the loss for generator and discriminator
plot_GAN_loss([G_losses, D_losses], ["G", "D"])
# Grab a batch of real images from the dataloader
# and display them next to the final generated samples.
plot_real_fake_images(next(iter(dataloader)), img_list)
# Generator Code (ablation: BatchNorm removed)
class Generator_woBN(nn.Module):
    """Ablation generator: identical architecture to `Generator`, but with
    every BatchNorm2d layer removed."""

    def __init__(self, ngpu):
        super(Generator_woBN, self).__init__()
        self.ngpu = ngpu
        stages = [
            # latent nz x 1 x 1 -> (ngf*4) x 4 x 4
            nn.ConvTranspose2d(nz, ngf * 4, kernel_size=4, stride=1, padding=0, bias=False),
            nn.ReLU(True),  # inplace ReLU
            # -> (ngf*2) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.ReLU(True),
            # -> ngf x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.ReLU(True),
            # -> nc x 32 x 32; Tanh matches pixels normalized to [-1, 1]
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        return self.main(input)
class Discriminator_woBN(nn.Module):
    """Ablation discriminator: identical architecture to `Discriminator`,
    but with every BatchNorm2d layer removed."""

    def __init__(self, ngpu):
        super(Discriminator_woBN, self).__init__()
        self.ngpu = ngpu
        stages = [
            # nc x 32 x 32 -> ndf x 16 x 16
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*2) x 8 x 8
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> 1 x 1 x 1; Sigmoid produces the probability
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        return self.main(input)
# Build the BatchNorm-free G and D for the "remove BatchNorm" ablation.
netG_noBN = initialize_net(Generator_woBN, weights_init, device, ngpu)
netD_noBN = initialize_net(Discriminator_woBN, weights_init, device, ngpu)
Generator_woBN(
(main): Sequential(
(0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): ReLU(inplace=True)
(2): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): ReLU(inplace=True)
(4): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): Tanh()
)
)
Discriminator_woBN(
(main): Sequential(
(0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): LeakyReLU(negative_slope=0.2, inplace=True)
(4): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(5): LeakyReLU(negative_slope=0.2, inplace=True)
(6): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(7): Sigmoid()
)
)
# Setup Adam optimizers for both G and D
# (same settings as the baseline so only the BatchNorm removal differs)
optimizerD_noBN = optim.Adam(netD_noBN.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG_noBN = optim.Adam(netG_noBN.parameters(), lr=lr, betas=(beta1, 0.999))
# Training Loop for the no-BatchNorm ablation — identical procedure to the
# baseline loop, using the _noBN networks and optimizers.
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
# For each batch in the dataloader
for i, data in enumerate(dataloader, 0):
############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
## Train with all-real batch
netD_noBN.zero_grad()
# Format batch
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size,), real_label, device=device)
# Forward pass real batch through D
output = netD_noBN(real_cpu).view(-1)
# Calculate loss on all-real batch
errD_real = criterion(output, label)
# Calculate gradients for D in backward pass
errD_real.backward()
D_x = output.mean().item()
## Train with all-fake batch
# Generate batch of latent vectors
noise = torch.randn(b_size, nz, 1, 1, device=device)
# Generate fake image batch with G
fake = netG_noBN(noise)
label.fill_(fake_label)
# Classify all fake batch with D
# detach() stops this D update from backpropagating into G
output = netD_noBN(fake.detach()).view(-1)
# Calculate D's loss on the all-fake batch
errD_fake = criterion(output, label)
# Calculate the gradients for this batch
# (they accumulate on top of the gradients from the real batch)
errD_fake.backward()
D_G_z1 = output.mean().item()
# Add the gradients from the all-real and all-fake batches
errD = errD_real + errD_fake
# Update D
optimizerD_noBN.step()
############################
# (2) Update G network: maximize log(D(G(z)))
###########################
netG_noBN.zero_grad()
label.fill_(real_label) # fake labels are real for generator cost
# Since we just updated D, perform another forward pass of all-fake batch through D
output = netD_noBN(fake).view(-1)
# Calculate G's loss based on this output
errG = criterion(output, label)
# Calculate gradients for G
errG.backward()
D_G_z2 = output.mean().item()
# Update G
optimizerG_noBN.step()
# Output training stats
if i % 50 == 0:
print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
% (epoch, num_epochs, i, len(dataloader),
errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
# Save Losses for plotting later
G_losses.append(errG.item())
D_losses.append(errD.item())
# Check how the generator is doing by saving G's output on fixed_noise
# (the fixed latent batch keeps the saved grids comparable over training)
if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
with torch.no_grad():
fake = netG_noBN(fixed_noise).detach().cpu()
img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
iters += 1
Starting Training Loop... [0/5][0/469] Loss_D: 1.3862 Loss_G: 0.6931 D(x): 0.5001 D(G(z)): 0.5000 / 0.5000 [0/5][50/469] Loss_D: 1.1975 Loss_G: 0.4640 D(x): 0.8304 D(G(z)): 0.6358 / 0.6289 [0/5][100/469] Loss_D: 0.7978 Loss_G: 1.0001 D(x): 0.7317 D(G(z)): 0.3819 / 0.3683 [0/5][150/469] Loss_D: 0.7587 Loss_G: 0.6820 D(x): 0.9521 D(G(z)): 0.5048 / 0.5060 [0/5][200/469] Loss_D: 1.0134 Loss_G: 0.5188 D(x): 0.9206 D(G(z)): 0.6011 / 0.5956 [0/5][250/469] Loss_D: 1.2352 Loss_G: 0.4831 D(x): 0.7766 D(G(z)): 0.6191 / 0.6171 [0/5][300/469] Loss_D: 0.9647 Loss_G: 0.6891 D(x): 0.7829 D(G(z)): 0.4992 / 0.5022 [0/5][350/469] Loss_D: 0.6815 Loss_G: 0.8246 D(x): 0.9121 D(G(z)): 0.4386 / 0.4389 [0/5][400/469] Loss_D: 0.6636 Loss_G: 0.7543 D(x): 0.9735 D(G(z)): 0.4702 / 0.4705 [0/5][450/469] Loss_D: 0.6894 Loss_G: 0.7112 D(x): 0.9874 D(G(z)): 0.4915 / 0.4912 [1/5][0/469] Loss_D: 0.7217 Loss_G: 0.6796 D(x): 0.9879 D(G(z)): 0.5078 / 0.5069 [1/5][50/469] Loss_D: 0.7208 Loss_G: 0.6978 D(x): 0.9716 D(G(z)): 0.4970 / 0.4977 [1/5][100/469] Loss_D: 0.6932 Loss_G: 0.7123 D(x): 0.9869 D(G(z)): 0.4923 / 0.4905 [1/5][150/469] Loss_D: 0.6925 Loss_G: 0.7006 D(x): 0.9957 D(G(z)): 0.4974 / 0.4963 [1/5][200/469] Loss_D: 0.7088 Loss_G: 0.6833 D(x): 0.9949 D(G(z)): 0.5050 / 0.5050 [1/5][250/469] Loss_D: 0.7178 Loss_G: 0.6713 D(x): 0.9980 D(G(z)): 0.5112 / 0.5110 [1/5][300/469] Loss_D: 0.7932 Loss_G: 0.6064 D(x): 0.9972 D(G(z)): 0.5462 / 0.5454 [1/5][350/469] Loss_D: 0.8464 Loss_G: 0.5747 D(x): 0.9913 D(G(z)): 0.5671 / 0.5629 [1/5][400/469] Loss_D: 0.7987 Loss_G: 0.6281 D(x): 0.9768 D(G(z)): 0.5391 / 0.5337 [1/5][450/469] Loss_D: 0.7648 Loss_G: 0.6776 D(x): 0.9526 D(G(z)): 0.5090 / 0.5079 [2/5][0/469] Loss_D: 0.7771 Loss_G: 0.6555 D(x): 0.9590 D(G(z)): 0.5195 / 0.5192 [2/5][50/469] Loss_D: 0.8526 Loss_G: 0.5935 D(x): 0.9574 D(G(z)): 0.5541 / 0.5525 [2/5][100/469] Loss_D: 0.8550 Loss_G: 0.5936 D(x): 0.9606 D(G(z)): 0.5569 / 0.5524 [2/5][150/469] Loss_D: 0.8686 Loss_G: 0.6420 D(x): 0.9240 D(G(z)): 0.5441 
/ 0.5270 [2/5][200/469] Loss_D: 0.7653 Loss_G: 0.6712 D(x): 0.9675 D(G(z)): 0.5184 / 0.5113 [2/5][250/469] Loss_D: 0.8039 Loss_G: 0.6414 D(x): 0.9527 D(G(z)): 0.5286 / 0.5269 [2/5][300/469] Loss_D: 0.8590 Loss_G: 0.6075 D(x): 0.9403 D(G(z)): 0.5482 / 0.5451 [2/5][350/469] Loss_D: 0.8569 Loss_G: 0.6128 D(x): 0.9533 D(G(z)): 0.5534 / 0.5423 [2/5][400/469] Loss_D: 0.9811 Loss_G: 0.5574 D(x): 0.9048 D(G(z)): 0.5792 / 0.5742 [2/5][450/469] Loss_D: 1.1287 Loss_G: 0.4944 D(x): 0.8440 D(G(z)): 0.6107 / 0.6120 [3/5][0/469] Loss_D: 1.0074 Loss_G: 0.5598 D(x): 0.8885 D(G(z)): 0.5811 / 0.5745 [3/5][50/469] Loss_D: 0.8605 Loss_G: 0.6677 D(x): 0.8894 D(G(z)): 0.5175 / 0.5136 [3/5][100/469] Loss_D: 0.9378 Loss_G: 0.6283 D(x): 0.8659 D(G(z)): 0.5392 / 0.5363 [3/5][150/469] Loss_D: 1.0784 Loss_G: 0.5549 D(x): 0.8788 D(G(z)): 0.6028 / 0.5766 [3/5][200/469] Loss_D: 0.9861 Loss_G: 0.5974 D(x): 0.8488 D(G(z)): 0.5546 / 0.5520 [3/5][250/469] Loss_D: 0.7779 Loss_G: 0.7662 D(x): 0.8968 D(G(z)): 0.4846 / 0.4659 [3/5][300/469] Loss_D: 0.9201 Loss_G: 0.7063 D(x): 0.9144 D(G(z)): 0.5607 / 0.4945 [3/5][350/469] Loss_D: 0.6878 Loss_G: 0.8877 D(x): 0.9105 D(G(z)): 0.4377 / 0.4159 [3/5][400/469] Loss_D: 1.0050 Loss_G: 0.5973 D(x): 0.8354 D(G(z)): 0.5488 / 0.5527 [3/5][450/469] Loss_D: 1.0070 Loss_G: 0.5646 D(x): 0.8747 D(G(z)): 0.5766 / 0.5705 [4/5][0/469] Loss_D: 0.6793 Loss_G: 0.8218 D(x): 0.9224 D(G(z)): 0.4470 / 0.4420 [4/5][50/469] Loss_D: 0.8761 Loss_G: 0.7617 D(x): 0.7194 D(G(z)): 0.3799 / 0.4678 [4/5][100/469] Loss_D: 0.7049 Loss_G: 0.8477 D(x): 0.8571 D(G(z)): 0.4194 / 0.4288 [4/5][150/469] Loss_D: 0.6927 Loss_G: 0.7637 D(x): 0.9449 D(G(z)): 0.4680 / 0.4672 [4/5][200/469] Loss_D: 0.6712 Loss_G: 0.7609 D(x): 0.9715 D(G(z)): 0.4729 / 0.4679 [4/5][250/469] Loss_D: 0.8869 Loss_G: 0.6227 D(x): 0.9268 D(G(z)): 0.5523 / 0.5373 [4/5][300/469] Loss_D: 0.8828 Loss_G: 0.6395 D(x): 0.8723 D(G(z)): 0.5169 / 0.5284 [4/5][350/469] Loss_D: 0.7383 Loss_G: 0.7631 D(x): 0.9129 D(G(z)): 0.4707 / 0.4678 
[4/5][400/469] Loss_D: 0.6556 Loss_G: 0.8087 D(x): 0.9466 D(G(z)): 0.4505 / 0.4459 [4/5][450/469] Loss_D: 0.7976 Loss_G: 0.7033 D(x): 0.9356 D(G(z)): 0.5169 / 0.4955
# plot the loss for generator and discriminator
plot_GAN_loss([G_losses, D_losses], ["G", "D"])
# Grab a batch of real images from the dataloader
# and show them next to the final samples from the no-BatchNorm run.
plot_real_fake_images(next(iter(dataloader)), img_list)
# Re-initialize the generator and discriminator so the next ablation
# starts from freshly initialized weights rather than the trained ones.
netG = initialize_net(Generator, weights_init, device, ngpu)
netD = initialize_net(Discriminator, weights_init, device, ngpu)
# Setup fresh Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
# Training Loop for the single-pass ablation: D is updated on ONE
# concatenated real+fake batch instead of two separate forward/backward passes.
# NOTE(review): with a concatenated batch, D's BatchNorm layers normalize
# statistics over the mixed real/fake batch — presumably the point of this
# ablation; confirm against the assignment description.
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
# For each batch in the dataloader
for i, data in enumerate(dataloader, 0):
############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
################################ YOUR CODE ################################
## Train with all-real and all-fake batches concatenated
netD.zero_grad()
# Format batch
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size,), real_label, device=device)
# Generate batch of latent vectors
noise = torch.randn(b_size, nz, 1, 1, device=device)
# Generate fake image batch with G
fake = netG(noise)
# Concatenate real and fake batches along dimension 0
# (detach() stops this D update from backpropagating into G)
mixed_data = torch.cat((real_cpu, fake.detach()), 0)
mixed_labels = torch.cat((label, torch.full((b_size,), fake_label, device=device)), 0)
# Forward pass mixed batch through D
output = netD(mixed_data).view(-1)
# Calculate loss on mixed batch
errD = criterion(output, mixed_labels)
# Calculate gradients for D in backward pass
errD.backward()
# D_x / D_G_z1 are computed for parity with the baseline loop,
# but only D_G_z2 is included in the print statement below.
D_x = output[:b_size].mean().item() # Average of real part
D_G_z1 = output[b_size:].mean().item() # Average of fake part
# Update D
optimizerD.step()
############################ END YOUR CODE ##############################
############################
# (2) Update G network: maximize log(D(G(z)))
###########################
netG.zero_grad()
label = torch.full((b_size,), real_label, device=device) # fake labels are real for generator cost
# Since we just updated D, perform another forward pass of all-fake batch through D
output = netD(fake).view(-1)
# Calculate G's loss based on this output
errG = criterion(output, label)
# Calculate gradients for G
errG.backward()
D_G_z2 = output.mean().item()
# Update G
optimizerG.step()
# Output training stats
if i % 50 == 0:
print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(G(z)): %.4f'
% (epoch, num_epochs, i, len(dataloader),
errD.item(), errG.item(), D_G_z2))
# Save Losses for plotting later
G_losses.append(errG.item())
D_losses.append(errD.item())
# Check how the generator is doing by saving G's output on fixed_noise
# (the fixed latent batch keeps the saved grids comparable over training)
if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
with torch.no_grad():
fake = netG(fixed_noise).detach().cpu()
img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
iters += 1
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): Tanh()
)
)
Discriminator(
(main): Sequential(
(0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(9): Sigmoid()
)
)
Starting Training Loop...
[0/5][0/469] Loss_D: 0.7306 Loss_G: 0.7855 D(G(z)): 0.4597
[0/5][50/469] Loss_D: 0.0267 Loss_G: 0.1076 D(G(z)): 0.8983
[0/5][100/469] Loss_D: 0.0101 Loss_G: 0.0427 D(G(z)): 0.9582
[0/5][150/469] Loss_D: 0.0059 Loss_G: 0.0252 D(G(z)): 0.9751
[0/5][200/469] Loss_D: 0.0038 Loss_G: 0.0169 D(G(z)): 0.9833
[0/5][250/469] Loss_D: 0.0027 Loss_G: 0.0121 D(G(z)): 0.9880
[0/5][300/469] Loss_D: 0.0020 Loss_G: 0.0092 D(G(z)): 0.9909
[0/5][350/469] Loss_D: 0.0016 Loss_G: 0.0072 D(G(z)): 0.9928
[0/5][400/469] Loss_D: 0.0013 Loss_G: 0.0059 D(G(z)): 0.9941
[0/5][450/469] Loss_D: 0.0011 Loss_G: 0.0047 D(G(z)): 0.9953
[1/5][0/469] Loss_D: 0.0010 Loss_G: 0.0045 D(G(z)): 0.9955
[1/5][50/469] Loss_D: 0.0008 Loss_G: 0.0037 D(G(z)): 0.9963
[1/5][100/469] Loss_D: 0.0007 Loss_G: 0.0033 D(G(z)): 0.9967
[1/5][150/469] Loss_D: 0.0006 Loss_G: 0.0029 D(G(z)): 0.9971
[1/5][200/469] Loss_D: 0.0005 Loss_G: 0.0026 D(G(z)): 0.9974
[1/5][250/469] Loss_D: 0.0005 Loss_G: 0.0023 D(G(z)): 0.9977
[1/5][300/469] Loss_D: 0.0004 Loss_G: 0.0021 D(G(z)): 0.9979
[1/5][350/469] Loss_D: 0.0004 Loss_G: 0.0019 D(G(z)): 0.9981
[1/5][400/469] Loss_D: 0.0003 Loss_G: 0.0017 D(G(z)): 0.9983
[1/5][450/469] Loss_D: 0.0003 Loss_G: 0.0016 D(G(z)): 0.9984
[2/5][0/469] Loss_D: 0.0003 Loss_G: 0.0015 D(G(z)): 0.9985
[2/5][50/469] Loss_D: 0.0003 Loss_G: 0.0014 D(G(z)): 0.9986
[2/5][100/469] Loss_D: 0.0002 Loss_G: 0.0013 D(G(z)): 0.9987
[2/5][150/469] Loss_D: 0.0002 Loss_G: 0.0012 D(G(z)): 0.9988
[2/5][200/469] Loss_D: 0.0002 Loss_G: 0.0011 D(G(z)): 0.9989
[2/5][250/469] Loss_D: 0.0002 Loss_G: 0.0010 D(G(z)): 0.9990
[2/5][300/469] Loss_D: 0.0002 Loss_G: 0.0010 D(G(z)): 0.9990
[2/5][350/469] Loss_D: 0.0002 Loss_G: 0.0009 D(G(z)): 0.9991
[2/5][400/469] Loss_D: 0.0002 Loss_G: 0.0009 D(G(z)): 0.9991
[2/5][450/469] Loss_D: 0.0001 Loss_G: 0.0008 D(G(z)): 0.9992
[3/5][0/469] Loss_D: 0.0001 Loss_G: 0.0008 D(G(z)): 0.9992
[3/5][50/469] Loss_D: 0.0001 Loss_G: 0.0008 D(G(z)): 0.9992
[3/5][100/469] Loss_D: 0.0001 Loss_G: 0.0007 D(G(z)): 0.9993
[3/5][150/469] Loss_D: 0.0001 Loss_G: 0.0007 D(G(z)): 0.9993
[3/5][200/469] Loss_D: 0.0001 Loss_G: 0.0007 D(G(z)): 0.9993
[3/5][250/469] Loss_D: 0.0001 Loss_G: 0.0006 D(G(z)): 0.9994
[3/5][300/469] Loss_D: 0.0001 Loss_G: 0.0006 D(G(z)): 0.9994
[3/5][350/469] Loss_D: 0.0001 Loss_G: 0.0006 D(G(z)): 0.9994
[3/5][400/469] Loss_D: 0.0001 Loss_G: 0.0005 D(G(z)): 0.9995
[3/5][450/469] Loss_D: 0.0001 Loss_G: 0.0005 D(G(z)): 0.9995
[4/5][0/469] Loss_D: 0.0001 Loss_G: 0.0005 D(G(z)): 0.9995
[4/5][50/469] Loss_D: 0.0001 Loss_G: 0.0005 D(G(z)): 0.9995
[4/5][100/469] Loss_D: 0.0001 Loss_G: 0.0005 D(G(z)): 0.9995
[4/5][150/469] Loss_D: 0.0001 Loss_G: 0.0004 D(G(z)): 0.9996
[4/5][200/469] Loss_D: 0.0001 Loss_G: 0.0004 D(G(z)): 0.9996
[4/5][250/469] Loss_D: 0.0001 Loss_G: 0.0004 D(G(z)): 0.9996
[4/5][300/469] Loss_D: 0.0001 Loss_G: 0.0004 D(G(z)): 0.9996
[4/5][350/469] Loss_D: 0.0001 Loss_G: 0.0004 D(G(z)): 0.9996
[4/5][400/469] Loss_D: 0.0001 Loss_G: 0.0004 D(G(z)): 0.9996
[4/5][450/469] Loss_D: 0.0001 Loss_G: 0.0003 D(G(z)): 0.9997
# Visualize the generator/discriminator loss curves for this run.
loss_series = [G_losses, D_losses]
series_names = ["G", "D"]
plot_GAN_loss(loss_series, series_names)
# Show one batch of real training images beside the saved generated grids.
real_batch = next(iter(dataloader))
plot_real_fake_images(real_batch, img_list)
# Re-initialize the generator and discriminator so this run starts fresh.
netG = initialize_net(Generator, weights_init, device, ngpu)
netD = initialize_net(Discriminator, weights_init, device, ngpu)
# Adam optimizers for both networks, sharing the same hyperparameters.
adam_betas = (beta1, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)
# Training Loop
# Lists to keep track of progress
img_list = []   # periodic grids of G(fixed_noise) to track visual progress
G_losses = []   # generator loss per iteration
D_losses = []   # discriminator loss per iteration
iters = 0
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        # NOTE(review): assumes real_label is a float (e.g. 1.) defined earlier
        # in the file; recent PyTorch requires a float fill value (or dtype=)
        # for BCELoss targets — confirm against the head of the notebook.
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()
        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        # detach() prevents this D step from backpropagating into G.
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()
        ############################
        # (2) Update G network
        ###########################
        ################################ YOUR CODE ################################
        netG.zero_grad()
        label.fill_(real_label) # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        # ABLATION: the generator target is flipped to 1 - real_label (the
        # fake label), deliberately replacing the usual non-saturating G loss.
        # This is the modified-loss variant being studied, not a bug.
        errG = criterion(output, 1 - label) # Modify the label to 1 - real_label
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()
        ############################ END YOUR CODE ##############################
        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): Tanh()
)
)
Discriminator(
(main): Sequential(
(0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(9): Sigmoid()
)
)
Starting Training Loop...
[0/5][0/469] Loss_D: 1.4630 Loss_G: 0.7272 D(x): 0.5020 D(G(z)): 0.5300 / 0.5109
[0/5][50/469] Loss_D: 0.0497 Loss_G: 0.0176 D(x): 0.9689 D(G(z)): 0.0178 / 0.0174
[0/5][100/469] Loss_D: 0.0173 Loss_G: 0.0061 D(x): 0.9890 D(G(z)): 0.0062 / 0.0061
[0/5][150/469] Loss_D: 0.0087 Loss_G: 0.0032 D(x): 0.9946 D(G(z)): 0.0032 / 0.0032
[0/5][200/469] Loss_D: 0.0053 Loss_G: 0.0020 D(x): 0.9968 D(G(z)): 0.0021 / 0.0020
[0/5][250/469] Loss_D: 0.0039 Loss_G: 0.0014 D(x): 0.9976 D(G(z)): 0.0014 / 0.0014
[0/5][300/469] Loss_D: 0.0028 Loss_G: 0.0011 D(x): 0.9983 D(G(z)): 0.0011 / 0.0011
[0/5][350/469] Loss_D: 0.0023 Loss_G: 0.0009 D(x): 0.9986 D(G(z)): 0.0009 / 0.0009
[0/5][400/469] Loss_D: 0.0019 Loss_G: 0.0007 D(x): 0.9988 D(G(z)): 0.0007 / 0.0007
[0/5][450/469] Loss_D: 0.0015 Loss_G: 0.0006 D(x): 0.9991 D(G(z)): 0.0006 / 0.0006
[1/5][0/469] Loss_D: 0.0015 Loss_G: 0.0005 D(x): 0.9991 D(G(z)): 0.0005 / 0.0005
[1/5][50/469] Loss_D: 0.0013 Loss_G: 0.0005 D(x): 0.9992 D(G(z)): 0.0005 / 0.0005
[1/5][100/469] Loss_D: 0.0010 Loss_G: 0.0004 D(x): 0.9994 D(G(z)): 0.0004 / 0.0004
[1/5][150/469] Loss_D: 0.0009 Loss_G: 0.0003 D(x): 0.9994 D(G(z)): 0.0003 / 0.0003
[1/5][200/469] Loss_D: 0.0008 Loss_G: 0.0003 D(x): 0.9995 D(G(z)): 0.0003 / 0.0003
[1/5][250/469] Loss_D: 0.0007 Loss_G: 0.0003 D(x): 0.9996 D(G(z)): 0.0003 / 0.0003
[1/5][300/469] Loss_D: 0.0006 Loss_G: 0.0002 D(x): 0.9996 D(G(z)): 0.0002 / 0.0002
[1/5][350/469] Loss_D: 0.0005 Loss_G: 0.0002 D(x): 0.9997 D(G(z)): 0.0002 / 0.0002
[1/5][400/469] Loss_D: 0.0005 Loss_G: 0.0002 D(x): 0.9997 D(G(z)): 0.0002 / 0.0002
[1/5][450/469] Loss_D: 0.0004 Loss_G: 0.0002 D(x): 0.9997 D(G(z)): 0.0002 / 0.0002
[2/5][0/469] Loss_D: 0.0004 Loss_G: 0.0002 D(x): 0.9997 D(G(z)): 0.0002 / 0.0002
[2/5][50/469] Loss_D: 0.0004 Loss_G: 0.0002 D(x): 0.9998 D(G(z)): 0.0002 / 0.0002
[2/5][100/469] Loss_D: 0.0004 Loss_G: 0.0001 D(x): 0.9998 D(G(z)): 0.0001 / 0.0001
[2/5][150/469] Loss_D: 0.0003 Loss_G: 0.0001 D(x): 0.9998 D(G(z)): 0.0001 / 0.0001
[2/5][200/469] Loss_D: 0.0003 Loss_G: 0.0001 D(x): 0.9998 D(G(z)): 0.0001 / 0.0001
[2/5][250/469] Loss_D: 0.0003 Loss_G: 0.0001 D(x): 0.9998 D(G(z)): 0.0001 / 0.0001
[2/5][300/469] Loss_D: 0.0003 Loss_G: 0.0001 D(x): 0.9998 D(G(z)): 0.0001 / 0.0001
[2/5][350/469] Loss_D: 0.0003 Loss_G: 0.0001 D(x): 0.9998 D(G(z)): 0.0001 / 0.0001
[2/5][400/469] Loss_D: 0.0002 Loss_G: 0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[2/5][450/469] Loss_D: 0.0002 Loss_G: 0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[3/5][0/469] Loss_D: 0.0002 Loss_G: 0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[3/5][50/469] Loss_D: 0.0002 Loss_G: 0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[3/5][100/469] Loss_D: 0.0002 Loss_G: 0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[3/5][150/469] Loss_D: 0.0002 Loss_G: 0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[3/5][200/469] Loss_D: 0.0002 Loss_G: 0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[3/5][250/469] Loss_D: 0.0001 Loss_G: 0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[3/5][300/469] Loss_D: 0.0001 Loss_G: 0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[3/5][350/469] Loss_D: 0.0001 Loss_G: 0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[3/5][400/469] Loss_D: 0.0001 Loss_G: 0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[3/5][450/469] Loss_D: 0.0001 Loss_G: 0.0001 D(x): 0.9999 D(G(z)): 0.0001 / 0.0001
[4/5][0/469] Loss_D: 0.0001 Loss_G: 0.0000 D(x): 0.9999 D(G(z)): 0.0000 / 0.0000
[4/5][50/469] Loss_D: 0.0001 Loss_G: 0.0000 D(x): 0.9999 D(G(z)): 0.0000 / 0.0000
[4/5][100/469] Loss_D: 0.0001 Loss_G: 0.0000 D(x): 0.9999 D(G(z)): 0.0000 / 0.0000
[4/5][150/469] Loss_D: 0.0001 Loss_G: 0.0000 D(x): 0.9999 D(G(z)): 0.0000 / 0.0000
[4/5][200/469] Loss_D: 0.0001 Loss_G: 0.0000 D(x): 0.9999 D(G(z)): 0.0000 / 0.0000
[4/5][250/469] Loss_D: 0.0001 Loss_G: 0.0000 D(x): 0.9999 D(G(z)): 0.0000 / 0.0000
[4/5][300/469] Loss_D: 0.0001 Loss_G: 0.0000 D(x): 1.0000 D(G(z)): 0.0000 / 0.0000
[4/5][350/469] Loss_D: 0.0001 Loss_G: 0.0000 D(x): 1.0000 D(G(z)): 0.0000 / 0.0000
[4/5][400/469] Loss_D: 0.0001 Loss_G: 0.0000 D(x): 1.0000 D(G(z)): 0.0000 / 0.0000
[4/5][450/469] Loss_D: 0.0001 Loss_G: 0.0000 D(x): 1.0000 D(G(z)): 0.0000 / 0.0000
# Plot how G's and D's losses evolved during the modified-loss run.
history = [G_losses, D_losses]
plot_GAN_loss(history, ["G", "D"])
# Side-by-side comparison: one real batch vs. the saved fake-image grids.
plot_real_fake_images(next(iter(dataloader)), img_list)
Use the `initialize_net` function provided in Task 1.0 to initialize the generator and discriminator WITHOUT weight initialization (HINT: there is no need to modify the code of the `initialize_net` function itself).
################################ YOUR CODE ################################
# Ablation: build G and D without the custom DCGAN weight initialization.
# Passing None in place of weights_init presumably makes initialize_net skip
# the custom init and keep PyTorch's default layer initialization — confirm
# against the Task 1.0 definition of initialize_net.
netG_woinit = initialize_net(Generator, None, device, ngpu)
netD_woinit = initialize_net(Discriminator, None, device, ngpu)
########################### END YOUR CODE ###############################
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 32, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): Tanh()
)
)
Discriminator(
(main): Sequential(
(0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(32, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(9): Sigmoid()
)
)
# Adam optimizers for the no-weight-init networks (same betas as baseline).
woinit_betas = (beta1, 0.999)
optimizerD_woinit = optim.Adam(netD_woinit.parameters(), lr=lr, betas=woinit_betas)
optimizerG_woinit = optim.Adam(netG_woinit.parameters(), lr=lr, betas=woinit_betas)
# Training bookkeeping: image snapshots, per-iteration losses, step counter.
img_list = []
G_losses = []
D_losses = []
iters = 0
print("Starting Training Loop...")
# Training loop for the no-weight-init ablation: identical to the baseline
# DCGAN loop, but driving the *_woinit networks and optimizers.
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD_woinit.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        # NOTE(review): assumes real_label is a float (e.g. 1.) — BCELoss
        # targets must be floats; confirm against the head of the notebook.
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD_woinit(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()
        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG_woinit(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        # detach() keeps this D step from backpropagating into G.
        output = netD_woinit(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD_woinit.step()
        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG_woinit.zero_grad()
        label.fill_(real_label) # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD_woinit(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG_woinit.step()
        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG_woinit(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1
Starting Training Loop... [0/5][0/469] Loss_D: 1.4309 Loss_G: 0.6980 D(x): 0.5045 D(G(z)): 0.5139 / 0.5034 [0/5][50/469] Loss_D: 0.6671 Loss_G: 1.1116 D(x): 0.8095 D(G(z)): 0.3604 / 0.3361 [0/5][100/469] Loss_D: 0.5063 Loss_G: 1.5368 D(x): 0.8209 D(G(z)): 0.2548 / 0.2249 [0/5][150/469] Loss_D: 0.3605 Loss_G: 1.9759 D(x): 0.8682 D(G(z)): 0.1886 / 0.1460 [0/5][200/469] Loss_D: 0.2287 Loss_G: 2.6202 D(x): 0.9105 D(G(z)): 0.1204 / 0.0839 [0/5][250/469] Loss_D: 0.1573 Loss_G: 3.0957 D(x): 0.9259 D(G(z)): 0.0731 / 0.0499 [0/5][300/469] Loss_D: 0.1034 Loss_G: 3.6207 D(x): 0.9534 D(G(z)): 0.0503 / 0.0310 [0/5][350/469] Loss_D: 0.0520 Loss_G: 4.1089 D(x): 0.9787 D(G(z)): 0.0292 / 0.0191 [0/5][400/469] Loss_D: 0.0590 Loss_G: 4.1156 D(x): 0.9748 D(G(z)): 0.0310 / 0.0194 [0/5][450/469] Loss_D: 0.0674 Loss_G: 4.1557 D(x): 0.9686 D(G(z)): 0.0301 / 0.0190 [1/5][0/469] Loss_D: 0.0380 Loss_G: 4.6593 D(x): 0.9861 D(G(z)): 0.0232 / 0.0126 [1/5][50/469] Loss_D: 0.0318 Loss_G: 4.6841 D(x): 0.9870 D(G(z)): 0.0182 / 0.0118 [1/5][100/469] Loss_D: 0.0590 Loss_G: 3.9095 D(x): 0.9652 D(G(z)): 0.0208 / 0.0255 [1/5][150/469] Loss_D: 0.0580 Loss_G: 4.1443 D(x): 0.9721 D(G(z)): 0.0284 / 0.0193 [1/5][200/469] Loss_D: 0.0372 Loss_G: 4.4027 D(x): 0.9767 D(G(z)): 0.0128 / 0.0148 [1/5][250/469] Loss_D: 0.0646 Loss_G: 4.1671 D(x): 0.9696 D(G(z)): 0.0315 / 0.0212 [1/5][300/469] Loss_D: 0.0592 Loss_G: 4.2155 D(x): 0.9733 D(G(z)): 0.0304 / 0.0176 [1/5][350/469] Loss_D: 0.0674 Loss_G: 4.1542 D(x): 0.9693 D(G(z)): 0.0338 / 0.0218 [1/5][400/469] Loss_D: 0.0333 Loss_G: 4.5988 D(x): 0.9876 D(G(z)): 0.0202 / 0.0133 [1/5][450/469] Loss_D: 0.0606 Loss_G: 3.8784 D(x): 0.9642 D(G(z)): 0.0193 / 0.0270 [2/5][0/469] Loss_D: 0.1381 Loss_G: 5.9040 D(x): 0.9846 D(G(z)): 0.1098 / 0.0039 [2/5][50/469] Loss_D: 2.3062 Loss_G: 0.4169 D(x): 0.2014 D(G(z)): 0.0002 / 0.6897 [2/5][100/469] Loss_D: 0.0758 Loss_G: 4.2113 D(x): 0.9589 D(G(z)): 0.0290 / 0.0197 [2/5][150/469] Loss_D: 0.0946 Loss_G: 3.7103 D(x): 0.9509 D(G(z)): 0.0386 
/ 0.0341 [2/5][200/469] Loss_D: 0.1628 Loss_G: 4.1826 D(x): 0.9775 D(G(z)): 0.1241 / 0.0209 [2/5][250/469] Loss_D: 0.1252 Loss_G: 4.2878 D(x): 0.9813 D(G(z)): 0.0966 / 0.0196 [2/5][300/469] Loss_D: 0.0500 Loss_G: 4.9315 D(x): 0.9871 D(G(z)): 0.0354 / 0.0113 [2/5][350/469] Loss_D: 0.1430 Loss_G: 3.3743 D(x): 0.9507 D(G(z)): 0.0825 / 0.0455 [2/5][400/469] Loss_D: 0.0754 Loss_G: 3.2040 D(x): 0.9736 D(G(z)): 0.0451 / 0.0588 [2/5][450/469] Loss_D: 0.0879 Loss_G: 3.8815 D(x): 0.9695 D(G(z)): 0.0535 / 0.0255 [3/5][0/469] Loss_D: 0.0878 Loss_G: 3.8022 D(x): 0.9468 D(G(z)): 0.0226 / 0.0317 [3/5][50/469] Loss_D: 0.1274 Loss_G: 4.7143 D(x): 0.9537 D(G(z)): 0.0662 / 0.0172 [3/5][100/469] Loss_D: 0.3286 Loss_G: 3.3465 D(x): 0.8949 D(G(z)): 0.1649 / 0.0547 [3/5][150/469] Loss_D: 0.1000 Loss_G: 3.9210 D(x): 0.9358 D(G(z)): 0.0271 / 0.0282 [3/5][200/469] Loss_D: 0.1147 Loss_G: 3.7722 D(x): 0.9474 D(G(z)): 0.0516 / 0.0361 [3/5][250/469] Loss_D: 0.2233 Loss_G: 1.8276 D(x): 0.8400 D(G(z)): 0.0210 / 0.2159 [3/5][300/469] Loss_D: 0.1482 Loss_G: 3.5661 D(x): 0.9354 D(G(z)): 0.0692 / 0.0415 [3/5][350/469] Loss_D: 0.1571 Loss_G: 2.5693 D(x): 0.9299 D(G(z)): 0.0717 / 0.1104 [3/5][400/469] Loss_D: 0.2600 Loss_G: 3.1801 D(x): 0.8795 D(G(z)): 0.1025 / 0.0640 [3/5][450/469] Loss_D: 0.1917 Loss_G: 2.7638 D(x): 0.8709 D(G(z)): 0.0338 / 0.0953 [4/5][0/469] Loss_D: 0.1196 Loss_G: 3.5909 D(x): 0.9598 D(G(z)): 0.0712 / 0.0404 [4/5][50/469] Loss_D: 0.1432 Loss_G: 3.6147 D(x): 0.9449 D(G(z)): 0.0768 / 0.0398 [4/5][100/469] Loss_D: 0.2123 Loss_G: 3.1947 D(x): 0.9158 D(G(z)): 0.0994 / 0.0602 [4/5][150/469] Loss_D: 0.1978 Loss_G: 3.4466 D(x): 0.8698 D(G(z)): 0.0369 / 0.0533 [4/5][200/469] Loss_D: 0.2423 Loss_G: 2.9390 D(x): 0.8847 D(G(z)): 0.0856 / 0.0834 [4/5][250/469] Loss_D: 0.3323 Loss_G: 1.9276 D(x): 0.7891 D(G(z)): 0.0352 / 0.2081 [4/5][300/469] Loss_D: 0.3603 Loss_G: 3.6882 D(x): 0.9622 D(G(z)): 0.2365 / 0.0487 [4/5][350/469] Loss_D: 0.2490 Loss_G: 3.3203 D(x): 0.9410 D(G(z)): 0.1494 / 0.0631 
[4/5][400/469] Loss_D: 0.1772 Loss_G: 3.6082 D(x): 0.9259 D(G(z)): 0.0828 / 0.0494 [4/5][450/469] Loss_D: 0.2115 Loss_G: 3.1421 D(x): 0.9218 D(G(z)): 0.1078 / 0.0764
# Loss curves for the no-weight-init ablation run.
plot_GAN_loss([G_losses, D_losses], ["G", "D"])
# Fetch one real batch and display it beside the generated snapshots.
sample_batch = next(iter(dataloader))
plot_real_fake_images(sample_batch, img_list)
Wasserstein GAN (WGAN) is an alternative training strategy to traditional GAN. WGAN may provide more stable learning and may avoid problems faced in traditional GAN training like mode collapse and vanishing gradient. We will not go through the whole derivation of this algorithm but if interested, you can find more details in the arXiv paper above and Prof. Inouye's lecture notes on Wasserstein GANs from ECE 570.
The objective function of WGAN is still a min-max problem but with a different objective: $$ \min_G \max_D \mathbb{E}_{x \sim p_{data}}[D(x)] - \mathbb{E}_{z \sim p_z}[D(G(z))] \,, $$ where $D$ must be a 1-Lipschitz function (rather than a classifier as in regular GANs) and $p_z$ is a standard normal distribution. Notice the similarities and differences with the original GAN objective: $$ \min_G \max_D \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log (1- D(G(z)))] \,, $$ where $D$ is a classifier. Note that in practice the WGAN paper refers to the discriminator as a "critic" and performs multiple critic updates per generator update during training.
We will not go through the derivation but one approximation algorithm for optimizing the WGAN objective is to apply weight clipping to all the weights, i.e., enforce that their absolute value is smaller than some constant $c$. The full pseudo-algorithm can be found on slide 17 in these slides on WGAN or in the original paper.
Modify the discriminator by removing the final Sigmoid layer (the critic outputs an unbounded score rather than a probability). Use the RMSprop optimizer for both networks with learning rate `lr_rms` (which we set to 5e-4, which is larger than the rate in the paper but works better for our purposes). Use the `torch.Tensor.clamp_()` function to clip the parameter values; you will need to do this for all parameters of the discriminator (see the algorithm for when to do this). The objective function is different for WGAN. It is simply the mean of the discriminator/critic for real data minus fake data, i.e., for WGAN the value function is $V(D,G)=\mathbb{E}_{x \sim p_{data}}[D(x)] - \mathbb{E}_{z \sim p_z}[D(G(z))]$. The expectations can be approximated by empirical expectations (i.e., average over samples in a batch).
Note that optimizers always assume gradient descent so when optimizing $D$, the loss will be negative of the value function, i.e.,$-V(D,G)$, which is equivalent to gradient ascent on $V(D,G)$.
For clamping the parameters, you will have to loop over all parameters of netD.parameters() and clamp each one in place.
For the algorithm on slide 17 of the WGAN slides, the function denoted by $f_w$ is equivalent to our $D$ and the function denoted $g_\theta$ is equivalent to our $G$. The parameters of netD are denoted by $w$. Finally, note that in Line 6, it shows gradient ascent (which can be implemented in torch as discussed above), while in line 11, it is using standard gradient descent.
class Discriminator_WGAN(nn.Module):
    """WGAN critic for 32x32 inputs.

    Same conv stack as the DCGAN Discriminator, but WITHOUT the final
    Sigmoid: the critic outputs an unbounded scalar score rather than a
    probability, as required by the WGAN objective.
    """
    def __init__(self, ngpu):
        super(Discriminator_WGAN, self).__init__()
        self.ngpu = ngpu
        ################################ YOUR CODE ################################
        # BUG FIX: the layers must be collected into self.main (read by
        # forward()); the original listed them as bare expressions that were
        # constructed and immediately discarded, so forward() raised
        # AttributeError on self.main.
        self.main = nn.Sequential(
            # input is (nc) x 32 x 32
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 16 x 16
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 8 x 8
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            # state size. 1 x 1 x 1 (raw critic score)
            #nn.Sigmoid() # Intentionally omitted: WGAN critic is not a classifier
        )
        ########################### END YOUR CODE ################################
    def forward(self, input):
        # Return the raw (unbounded) critic score for each input image.
        return self.main(input)
# Re-initialize the generator and the (sigmoid-free) WGAN critic.
netG = initialize_net(Generator, weights_init, device, ngpu)
netD = initialize_net(Discriminator_WGAN, weights_init, device, ngpu)
############################ YOUR CODE ############################
# RMSprop optimizers for both networks at the shared learning rate lr_rms,
# as prescribed by the WGAN algorithm.
optimizerD, optimizerG = (
    optim.RMSprop(net.parameters(), lr=lr_rms) for net in (netD, netG)
)
######################## # END YOUR CODE ##########################
# Training Loop for WGAN with weight clipping.
# Lists to keep track of progress
img_list = []   # periodic grids of G(fixed_noise)
G_losses = []   # generator loss per outer iteration
D_losses = []   # critic loss per outer iteration (last critic step)
n_critic = 5    # critic updates per generator update (WGAN paper default)
c = 0.01        # weight-clipping constant bounding the Lipschitz constant
dataloader_iter = iter(dataloader)
print("Starting Training Loop...")
num_iters = 1000
for iters in range(num_iters):
    ###########################################################################
    # (1) Train Discriminator more: minimize -(mean(D(real))-mean(D(fake)))
    ###########################################################################
    for p in netD.parameters():
        p.requires_grad = True
    for idx_critic in range(n_critic):
        netD.zero_grad()
        try:
            data = next(dataloader_iter)
        except StopIteration:
            # Dataloader exhausted: restart it and fetch the next batch.
            dataloader_iter = iter(dataloader)
            data = next(dataloader_iter)
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        D_real = netD(real_cpu).view(-1)
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        # FIX: detach the fake batch. The critic step must not backprop
        # through the generator; without detach(), D_loss.backward() built and
        # traversed G's graph on every critic step. G's stray gradients were
        # cleared later by netG.zero_grad(), so parameter updates are
        # unchanged — this just removes the wasted computation.
        D_fake = netD(fake.detach()).view(-1)
        ############################ YOUR CODE ############################
        # Negative empirical value function -> gradient ASCENT on V(D,G).
        D_loss = -(D_real.mean() - D_fake.mean())
        # Backpropagate the loss function and update the optimizer
        D_loss.backward()
        optimizerD.step()
        # Clip the D network parameters to be within -c and c by using `clamp_()`
        # Note that if all weights are bounded, then the Lipschitz constant is bounded.
        for p in netD.parameters():
            p.data.clamp_(-c, c)
        ######################## # END YOUR CODE ##########################
    ###########################################################################
    # (2) Update G network: minimize -mean(D(fake)) (Update only once in 5 epochs)
    ###########################################################################
    # Freeze the critic so the G update does not accumulate critic gradients.
    for p in netD.parameters():
        p.requires_grad = False
    netG.zero_grad()
    noise = torch.randn(b_size, nz, 1, 1, device=device)
    fake = netG(noise)
    D_fake = netD(fake).view(-1)
    ################################ YOUR CODE ################################
    # Minimizing -mean(D(fake)) pushes G toward higher critic scores.
    G_loss = -D_fake.mean()
    # Backpropagate the loss function and update the optimizer
    G_loss.backward()
    optimizerG.step()
    ############################# END YOUR CODE ##############################
    # Output training stats
    if iters % 10 == 0:
        print('[%4d/%4d] Loss_D: %6.4f Loss_G: %6.4f'
              % (iters, num_iters, D_loss.item(), G_loss.item()))
    # Save Losses for plotting later
    G_losses.append(G_loss.item())
    D_losses.append(D_loss.item())
    # Check how the generator is doing by saving G's output on fixed_noise
    if (iters % 100 == 0):
        with torch.no_grad():
            fake = netG(fixed_noise).detach().cpu()
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
Starting Training Loop... [ 0/1000] Loss_D: -0.0079 Loss_G: -0.4973 [ 10/1000] Loss_D: -0.0105 Loss_G: -0.4956 [ 20/1000] Loss_D: -0.0107 Loss_G: -0.4949 [ 30/1000] Loss_D: -0.0110 Loss_G: -0.4948 [ 40/1000] Loss_D: -0.0107 Loss_G: -0.4949 [ 50/1000] Loss_D: -0.0099 Loss_G: -0.4952 [ 60/1000] Loss_D: -0.0085 Loss_G: -0.4959 [ 70/1000] Loss_D: -0.0063 Loss_G: -0.4960 [ 80/1000] Loss_D: -0.0064 Loss_G: -0.4966 [ 90/1000] Loss_D: -0.0060 Loss_G: -0.4973 [ 100/1000] Loss_D: -0.0057 Loss_G: -0.4975 [ 110/1000] Loss_D: -0.0048 Loss_G: -0.4970 [ 120/1000] Loss_D: -0.0049 Loss_G: -0.4985 [ 130/1000] Loss_D: -0.0043 Loss_G: -0.4988 [ 140/1000] Loss_D: -0.0038 Loss_G: -0.4975 [ 150/1000] Loss_D: -0.0038 Loss_G: -0.4988 [ 160/1000] Loss_D: -0.0033 Loss_G: -0.4980 [ 170/1000] Loss_D: -0.0032 Loss_G: -0.4999 [ 180/1000] Loss_D: -0.0031 Loss_G: -0.5000 [ 190/1000] Loss_D: -0.0029 Loss_G: -0.4979 [ 200/1000] Loss_D: -0.0028 Loss_G: -0.4998 [ 210/1000] Loss_D: -0.0026 Loss_G: -0.4998 [ 220/1000] Loss_D: -0.0019 Loss_G: -0.4987 [ 230/1000] Loss_D: -0.0021 Loss_G: -0.4995 [ 240/1000] Loss_D: -0.0019 Loss_G: -0.4990 [ 250/1000] Loss_D: -0.0021 Loss_G: -0.4995 [ 260/1000] Loss_D: -0.0018 Loss_G: -0.5001 [ 270/1000] Loss_D: -0.0017 Loss_G: -0.4988 [ 280/1000] Loss_D: -0.0015 Loss_G: -0.5005 [ 290/1000] Loss_D: -0.0016 Loss_G: -0.4997 [ 300/1000] Loss_D: -0.0015 Loss_G: -0.4983 [ 310/1000] Loss_D: -0.0013 Loss_G: -0.5002 [ 320/1000] Loss_D: -0.0014 Loss_G: -0.4997 [ 330/1000] Loss_D: -0.0012 Loss_G: -0.4993 [ 340/1000] Loss_D: -0.0011 Loss_G: -0.4995 [ 350/1000] Loss_D: -0.0012 Loss_G: -0.5002 [ 360/1000] Loss_D: -0.0013 Loss_G: -0.4993 [ 370/1000] Loss_D: -0.0012 Loss_G: -0.5001 [ 380/1000] Loss_D: -0.0011 Loss_G: -0.4994 [ 390/1000] Loss_D: -0.0012 Loss_G: -0.4996 [ 400/1000] Loss_D: -0.0010 Loss_G: -0.4984 [ 410/1000] Loss_D: -0.0010 Loss_G: -0.4997 [ 420/1000] Loss_D: -0.0011 Loss_G: -0.4975 [ 430/1000] Loss_D: -0.0009 Loss_G: -0.4998 [ 440/1000] Loss_D: -0.0010 Loss_G: -0.5011 [ 
450/1000] Loss_D: -0.0011 Loss_G: -0.4991 [ 460/1000] Loss_D: -0.0009 Loss_G: -0.4997 [ 470/1000] Loss_D: -0.0008 Loss_G: -0.4999 [ 480/1000] Loss_D: -0.0010 Loss_G: -0.5009 [ 490/1000] Loss_D: -0.0010 Loss_G: -0.4979 [ 500/1000] Loss_D: -0.0009 Loss_G: -0.4978 [ 510/1000] Loss_D: -0.0008 Loss_G: -0.4997 [ 520/1000] Loss_D: -0.0008 Loss_G: -0.4995 [ 530/1000] Loss_D: -0.0008 Loss_G: -0.4989 [ 540/1000] Loss_D: -0.0007 Loss_G: -0.4995 [ 550/1000] Loss_D: -0.0007 Loss_G: -0.5009 [ 560/1000] Loss_D: -0.0008 Loss_G: -0.4993 [ 570/1000] Loss_D: -0.0008 Loss_G: -0.5010 [ 580/1000] Loss_D: -0.0006 Loss_G: -0.5001 [ 590/1000] Loss_D: -0.0007 Loss_G: -0.5010 [ 600/1000] Loss_D: -0.0008 Loss_G: -0.5007 [ 610/1000] Loss_D: -0.0007 Loss_G: -0.5006 [ 620/1000] Loss_D: -0.0007 Loss_G: -0.5011 [ 630/1000] Loss_D: -0.0008 Loss_G: -0.5004 [ 640/1000] Loss_D: -0.0007 Loss_G: -0.5006 [ 650/1000] Loss_D: -0.0007 Loss_G: -0.5003 [ 660/1000] Loss_D: -0.0007 Loss_G: -0.5007 [ 670/1000] Loss_D: -0.0008 Loss_G: -0.4994 [ 680/1000] Loss_D: -0.0005 Loss_G: -0.5004 [ 690/1000] Loss_D: -0.0006 Loss_G: -0.5001 [ 700/1000] Loss_D: -0.0006 Loss_G: -0.5006 [ 710/1000] Loss_D: -0.0006 Loss_G: -0.4996 [ 720/1000] Loss_D: -0.0006 Loss_G: -0.5004 [ 730/1000] Loss_D: -0.0007 Loss_G: -0.4997 [ 740/1000] Loss_D: -0.0006 Loss_G: -0.4997 [ 750/1000] Loss_D: -0.0007 Loss_G: -0.5010 [ 760/1000] Loss_D: -0.0006 Loss_G: -0.5006 [ 770/1000] Loss_D: -0.0005 Loss_G: -0.5008 [ 780/1000] Loss_D: -0.0006 Loss_G: -0.5006 [ 790/1000] Loss_D: -0.0006 Loss_G: -0.5012 [ 800/1000] Loss_D: -0.0005 Loss_G: -0.4999 [ 810/1000] Loss_D: -0.0007 Loss_G: -0.5003 [ 820/1000] Loss_D: -0.0005 Loss_G: -0.4998 [ 830/1000] Loss_D: -0.0004 Loss_G: -0.5000 [ 840/1000] Loss_D: -0.0004 Loss_G: -0.5004 [ 850/1000] Loss_D: -0.0005 Loss_G: -0.5000 [ 860/1000] Loss_D: -0.0006 Loss_G: -0.4998 [ 870/1000] Loss_D: -0.0006 Loss_G: -0.4997 [ 880/1000] Loss_D: -0.0005 Loss_G: -0.4994 [ 890/1000] Loss_D: -0.0005 Loss_G: -0.5006 [ 900/1000] Loss_D: 
-0.0005 Loss_G: -0.4991 [ 910/1000] Loss_D: -0.0006 Loss_G: -0.4996 [ 920/1000] Loss_D: -0.0005 Loss_G: -0.4998 [ 930/1000] Loss_D: -0.0006 Loss_G: -0.5006 [ 940/1000] Loss_D: -0.0005 Loss_G: -0.4998 [ 950/1000] Loss_D: -0.0006 Loss_G: -0.5010 [ 960/1000] Loss_D: -0.0005 Loss_G: -0.4991 [ 970/1000] Loss_D: -0.0004 Loss_G: -0.4989 [ 980/1000] Loss_D: -0.0004 Loss_G: -0.5001 [ 990/1000] Loss_D: -0.0005 Loss_G: -0.5000
# Loss curves for the weight-clipped WGAN run.
wgan_history = [G_losses, D_losses]
plot_GAN_loss(wgan_history, ["G", "D"])
# One real batch shown next to the saved generator snapshots.
plot_real_fake_images(next(iter(dataloader)), img_list)
Use slide 19 in Lecture note for WGAN to implement WGAN-GP algorithm.
Use `torch.autograd.grad` to compute the gradient of the critic output with respect to the interpolated inputs. You will need to set `outputs`, `inputs`, `grad_outputs`, and both `create_graph=True` and `retain_graph=True` (because we want to backprop through this gradient calculation for the final objective). `grad_norm = torch.sqrt((grad**2).sum(1) + 1e-14)` is a simple way to compute the norm. Train the model with the modified networks and visualize the results.
# Setup networks for WGAN-GP
netG = initialize_net(Generator, weights_init, device, ngpu)
netD = initialize_net(Discriminator_WGAN, weights_init, device, ngpu)
# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=5e-4, betas=(0.5, 0.9))
optimizerG = optim.Adam(netG.parameters(), lr=5e-4, betas=(0.5, 0.9))
# Training Loop
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
n_critic = 5      # critic updates per generator update
lambda_gp = 10    # gradient-penalty coefficient (WGAN-GP paper default)
dataloader_iter = iter(dataloader)
print("Starting Training Loop...")
num_iters = 1000
for iters in range(num_iters):
    ###########################################################################
    # (1) Train Discriminator more: minimize -(mean(D(real))-mean(D(fake)))+GP
    ###########################################################################
    for p in netD.parameters():
        p.requires_grad = True
    for idx_critic in range(n_critic):
        netD.zero_grad()
        try:
            data = next(dataloader_iter)
        except StopIteration:
            dataloader_iter = iter(dataloader)
            # BUG FIX: the original never re-fetched after restarting the
            # iterator, so a stale `data` batch was silently reused (and the
            # very first pass could hit an undefined `data`).
            data = next(dataloader_iter)
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        D_real = netD(real_cpu).view(-1)
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        # detach(): the critic step must not backprop into the generator.
        D_fake = netD(fake.detach()).view(-1)
        ############################ YOUR CODE ############################
        # Gradient penalty: sample points uniformly on lines between real and
        # fake images and penalize the critic's gradient norm away from 1.
        eps = torch.rand(b_size, 1, 1, 1, device=device)
        interp = (eps * real_cpu + (1 - eps) * fake.detach()).requires_grad_(True)
        D_interp = netD(interp).view(-1)
        # create_graph/retain_graph so the penalty itself is differentiable
        # w.r.t. the critic parameters in the final objective.
        grad = torch.autograd.grad(outputs=D_interp, inputs=interp,
                                   grad_outputs=torch.ones_like(D_interp),
                                   create_graph=True, retain_graph=True)[0]
        grad = grad.view(b_size, -1)
        # Small epsilon inside sqrt avoids a NaN gradient at exactly zero norm.
        grad_norm = torch.sqrt((grad ** 2).sum(1) + 1e-14)
        gradient_penalty = ((grad_norm - 1.0) ** 2).mean()
        # Negative value function plus penalty -> gradient ascent on V(D,G)-GP.
        D_loss = -(D_real.mean() - D_fake.mean()) + lambda_gp * gradient_penalty
        D_loss.backward()
        optimizerD.step()
        ######################## # END YOUR CODE ##########################
    ###########################################################################
    # (2) Update G network: minimize -mean(D(fake)) (Update only once in 5 epochs)
    ###########################################################################
    # Freeze the critic so the G update does not accumulate critic gradients.
    for p in netD.parameters():
        p.requires_grad = False
    netG.zero_grad()
    noise = torch.randn(b_size, nz, 1, 1, device=device)
    fake = netG(noise)
    D_fake = netD(fake).view(-1)
    ################################ YOUR CODE ################################
    # Minimizing -mean(D(fake)) pushes G toward higher critic scores.
    G_loss = -D_fake.mean()
    G_loss.backward()
    optimizerG.step()
    ############################# END YOUR CODE ##############################
    # Output training stats
    if iters % 10 == 0:
        print('[%4d/%4d] Loss_D: %6.4f Loss_G: %6.4f'
              % (iters, num_iters, D_loss.item(), G_loss.item()))
    # Save Losses for plotting later
    G_losses.append(G_loss.item())
    D_losses.append(D_loss.item())
    # Check how the generator is doing by saving G's output on fixed_noise
    if (iters % 100 == 0):
        with torch.no_grad():
            fake = netG(fixed_noise).detach().cpu()
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
# Loss curves for the WGAN-GP run.
plot_GAN_loss([G_losses, D_losses], ["G", "D"])
# Display a real batch alongside the generated-image snapshots.
reals = next(iter(dataloader))
plot_real_fake_images(reals, img_list)